Workflow
#Install the required packages (already-installed dependencies are skipped)
%pip install lightning gdown
Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
[... dependency resolution and download progress elided ...]
Successfully installed lightning-2.5.1 lightning-utilities-0.14.2 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-lightning-2.5.1 torchmetrics-1.7.0
import os
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import geopandas as gpd
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights

import lightning as L
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
#Download the zipped tree crown data
!gdown 1svN8wVUmgvyQeOgj_NZkQtp7m7ehUEu2
Downloading...
From (original): https://drive.google.com/uc?id=1svN8wVUmgvyQeOgj_NZkQtp7m7ehUEu2
From (redirected): https://drive.google.com/uc?id=1svN8wVUmgvyQeOgj_NZkQtp7m7ehUEu2&confirm=t&uuid=b1512dfe-2b6e-4110-9cfe-70496930319d
To: /content/qc_crowns.zip
100% 69.0M/69.0M [00:00<00:00, 141MB/s]
#Remove data dir if it already exists
if os.path.exists("data"):
    shutil.rmtree("data")
#Unzip the data
!unzip qc_crowns.zip -d data/
#Remove zip file
!rm qc_crowns.zip
Archive:  qc_crowns.zip
   creating: data/clipped_crowns/
  inflating: data/clipped_crowns/crown_1025.png
  [... remaining crown_*.png files elided ...]
  inflating: data/tree_crowns_subset.gpkg
# List files in the current directory
!ls
data sample_data
#Load the crown polygons
crowns_df = gpd.read_file('data/tree_crowns_subset.gpkg')
# Map class labels to binary values
label_mapping = {'coniferous': 0, 'deciduous': 1}
crowns_df['label'] = crowns_df['species_type'].map(label_mapping)
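As a quick sanity check (not part of the original notebook), it is worth confirming that every species_type value was covered by the mapping, since unmapped values silently become NaN:

#Sanity check (assumption: only 'coniferous' and 'deciduous' occur)
unmapped = crowns_df[crowns_df['label'].isna()]
assert unmapped.empty, f"Unmapped species types: {unmapped['species_type'].unique()}"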
#Set data dir
img_dir = 'data/clipped_crowns'
img_fpaths = list(Path(img_dir).glob("*.png"))
#Convert the fpath list to a data frame
img_df = pd.DataFrame(img_fpaths, columns=['fpath'])
img_df['crown_id'] = img_df['fpath'].apply(lambda x: int(x.stem.split("_")[1]))
#Join with crowns_df
crowns_df = crowns_df.merge(img_df, on='crown_id', how='left')
crowns_df
|  | label | common_name | scientific_name | genus | crown_id | species_type | minx | miny | maxx | maxy | geometry | fpath |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | Balsam fir | Abies balsamea | Pinaceae | 8340 | coniferous | 577189.0365 | 5.093486e+06 | 577192.0568 | 5.093488e+06 | MULTIPOLYGON (((577191.446 5093488.217, 577191... | data/clipped_crowns/crown_8340.png |
| 1 | 0 | Balsam fir | Abies balsamea | Pinaceae | 9399 | coniferous | 576957.3289 | 5.093309e+06 | 576960.9351 | 5.093313e+06 | MULTIPOLYGON (((576958.412 5093313.133, 576958... | data/clipped_crowns/crown_9399.png |
| 2 | 0 | Balsam fir | Abies balsamea | Pinaceae | 2458 | coniferous | 577064.1428 | 5.093336e+06 | 577066.9213 | 5.093339e+06 | MULTIPOLYGON (((577066.056 5093338.765, 577065... | data/clipped_crowns/crown_2458.png |
| 3 | 0 | Balsam fir | Abies balsamea | Pinaceae | 2492 | coniferous | 577052.4109 | 5.093352e+06 | 577054.2873 | 5.093355e+06 | MULTIPOLYGON (((577054.098 5093354.535, 577054... | data/clipped_crowns/crown_2492.png |
| 4 | 0 | Balsam fir | Abies balsamea | Pinaceae | 567 | coniferous | 577186.6727 | 5.093215e+06 | 577191.7753 | 5.093218e+06 | MULTIPOLYGON (((577190.923 5093217.595, 577190... | data/clipped_crowns/crown_567.png |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 595 | 1 | Red maple | Acer rubrum | Sapindaceae | 54 | deciduous | 577088.2029 | 5.093114e+06 | 577093.3687 | 5.093119e+06 | MULTIPOLYGON (((577090.925 5093119.305, 577090... | data/clipped_crowns/crown_54.png |
| 596 | 1 | Red maple | Acer rubrum | Sapindaceae | 1327 | deciduous | 577074.5608 | 5.093307e+06 | 577076.5100 | 5.093309e+06 | MULTIPOLYGON (((577076.149 5093308.862, 577076... | data/clipped_crowns/crown_1327.png |
| 597 | 1 | Red maple | Acer rubrum | Sapindaceae | 6126 | deciduous | 577308.0109 | 5.093633e+06 | 577310.8445 | 5.093635e+06 | MULTIPOLYGON (((577310.509 5093634.769, 577310... | data/clipped_crowns/crown_6126.png |
| 598 | 1 | Red maple | Acer rubrum | Sapindaceae | 5284 | deciduous | 577443.5990 | 5.093582e+06 | 577452.3151 | 5.093589e+06 | MULTIPOLYGON (((577448.862 5093588.134, 577448... | data/clipped_crowns/crown_5284.png |
| 599 | 1 | Red maple | Acer rubrum | Sapindaceae | 6506 | deciduous | 577315.9984 | 5.093473e+06 | 577319.4037 | 5.093477e+06 | MULTIPOLYGON (((577318.737 5093475.773, 577318... | data/clipped_crowns/crown_6506.png |
600 rows × 12 columns
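Because the merge above uses how='left', any crown without a matching image ends up with a missing fpath; a quick hedged check (not in the original) makes that visible before building the dataset:

#Check for crowns that did not match an image file
print("Crowns without an image:", crowns_df['fpath'].isna().sum())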
import seaborn as sns
import matplotlib.pyplot as plt
# Create the count plot with 'label'
ax = sns.countplot(data=crowns_df, x='label', hue='label', palette='viridis', legend=False)
# Add a custom legend
legend_labels = {0: 'Coniferous', 1: 'Deciduous'}
handles = [plt.Rectangle((0, 0), 1, 1, color=ax.patches[i].get_facecolor()) for i in range(len(legend_labels))]
plt.legend(handles, legend_labels.values(), title="Tree Type")
# Set labels and title
plt.xlabel('Label')
plt.ylabel('Count')
plt.title('Distribution of Labels')
plt.show()
class TreeCrownDataset(Dataset):
    def __init__(self, crowns_df, split, target_res=256, train_augmentations=[]):
        self.target_res = target_res
        self.split = split
        self.crowns_df = crowns_df
        self.train_augmentations = train_augmentations
        # Create a transform to resize the crown images and convert them to tensors
        self.transforms = [
            transforms.Resize((target_res, target_res)),
            transforms.ToTensor(),
        ]
        #Add additional transforms for data augmentation if using the train dataset
        if self.split == 'train':
            self.transforms.extend(self.train_augmentations)
        # Build transform pipeline
        self.transforms = transforms.Compose(self.transforms)

    def __len__(self):
        return len(self.crowns_df)

    def __getitem__(self, idx):
        target_crown = self.crowns_df.iloc[idx]
        label = torch.tensor(target_crown['label']).long()
        crown_img = Image.open(target_crown['fpath']).convert('RGB')
        crown_tensor = self.transforms(crown_img)
        crown_id = target_crown['crown_id']
        return crown_tensor, label, crown_id
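Note that this pipeline resizes and converts images to [0, 1] tensors but applies no channel normalization. The pretrained ResNet-50 weights used later were trained with ImageNet statistics, so a minimal sketch of adding that step (an optional change, not part of the original class) would be:

#Optional: normalize with ImageNet statistics to match the pretrained backbone
imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
#e.g. append it to self.transforms before building the Compose pipeline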
class TreeCrownDataModule(L.LightningDataModule):
    def __init__(self, crowns_df, batch_size=32, train_augmentations=[]):
        super().__init__()
        self.crowns_df = crowns_df
        self.batch_size = batch_size
        self.train_augmentations = train_augmentations

    def setup(self, stage=None):
        #Split data into three dataframes for train/val/test
        train_val_df, self.test_df = train_test_split(self.crowns_df,
                                                      test_size=0.15,
                                                      random_state=42)
        self.train_df, self.val_df = train_test_split(train_val_df,
                                                      test_size=0.17,
                                                      random_state=42)
        #Report dataset sizes
        for name, df in [("Train", self.train_df),
                         ("Val", self.val_df),
                         ("Test", self.test_df)]:
            print(f"{name} dataset size: {len(df)}",
                  f"({round(len(df)/len(self.crowns_df)*100, 0)}%)")
        # Instantiate datasets; only the train split receives the augmentations
        self.train_dataset = TreeCrownDataset(self.train_df, split='train',
                                              train_augmentations=self.train_augmentations)
        self.val_dataset = TreeCrownDataset(self.val_df, split='val')
        self.test_dataset = TreeCrownDataset(self.test_df, split='test')

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          shuffle=False)
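The dataloaders above use the default single-process loading. On a GPU runtime, a hedged variant of train_dataloader (the worker count is an assumption, not part of the original module) can reduce the input bottleneck:

#Sketch: the same train_dataloader with parallel loading enabled
def train_dataloader(self):
    return DataLoader(self.train_dataset,
                      batch_size=self.batch_size,
                      shuffle=True,
                      num_workers=2,    #assumption: a modest worker count for Colab
                      pin_memory=True)  #speeds up host-to-GPU copies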
#Set the training data augmentations
train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation([-90, 90]),
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0))
]
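To see what these augmentations actually do, a small hedged sketch (not in the original) applies the pipeline to one crown image and displays the result; the geometric transforms here operate on tensors, so they go after ToTensor:

#Visualize one augmented crown (illustrative only)
aug = transforms.Compose([transforms.Resize((256, 256)),
                          transforms.ToTensor(),
                          *train_augmentations])
sample_img = Image.open(crowns_df['fpath'].iloc[0]).convert('RGB')
plt.imshow(aug(sample_img).permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()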
# Test the datamodule
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=train_augmentations)
crowns_datamodule.setup()
# Test loading a sample
sample = crowns_datamodule.train_dataset[0]
print(sample[0].shape)
print(sample[1])
Train dataset size: 423 (70.0%)
Val dataset size: 87 (14.0%)
Test dataset size: 90 (15.0%)
torch.Size([3, 256, 256])
tensor(0)
class CNN(L.LightningModule):
    def __init__(self, lr, pretrained_weights=True):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained_weights else None)  # IMAGENET1K_V2 vs. random init
        # Modify the final fc layer of the model to output a single value for binary classification
        self.model.fc = nn.Linear(self.model.fc.in_features, 1)
        #Add a sigmoid activation to the end of the model
        self.model = nn.Sequential(self.model, nn.Sigmoid())
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x).squeeze()
        loss = self.criterion(y_hat, y.float())
        self.log('train_loss', loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x).squeeze()
        loss = self.criterion(y_hat, y.float())
        self.log('val_loss', loss, on_epoch=True, on_step=False)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y, crown_id = batch
        y_hat = self(x).squeeze()
        return y_hat, y, crown_id

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
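A common, more numerically stable alternative (a variant, not what this notebook uses) drops the Sigmoid layer and lets nn.BCEWithLogitsLoss fuse the sigmoid with the loss, recovering probabilities only at inference:

#Sketch of the logits-based variant
criterion = nn.BCEWithLogitsLoss()        #applies sigmoid internally, more stable
logits = torch.randn(4)                   #stand-in for raw model outputs (no Sigmoid)
targets = torch.tensor([0., 1., 1., 0.])
loss = criterion(logits, targets)
probs = torch.sigmoid(logits)             #probabilities only when you need them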
#Instantiate the model with a single sigmoid output (coniferous vs. deciduous)
model = CNN(lr=0.0001)
print(model)
#Try passing some data through the model
batch, labels, ids = next(iter(crowns_datamodule.train_dataloader()))
# Pass batch through the model
y_hat = model(batch)
print("\nCrown IDs:\n", ids)
print("\nImage batch shape:\n", batch.shape)
print("\nOutput tensor shape:\n", y_hat.shape)
#View the predicted class probabilities
print("\nPredicted class probabilities:\n",
y_hat.detach().cpu().numpy().squeeze())
CNN(
(model): Sequential(
(0): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1, bias=True)
)
(1): Sigmoid()
)
(criterion): BCELoss()
)
Crown IDs:
tensor([ 333, 4554, 5438, 9405, 2964, 9037, 3223, 297, 9233, 9073, 9235, 216,
359, 6924, 8943, 8175, 8020, 3814, 8969, 919, 9112, 1494, 7258, 4119,
6522, 200, 7671, 5744, 5605, 2937, 5525, 1535])
Image batch shape:
torch.Size([32, 3, 256, 256])
Output tensor shape:
torch.Size([32, 1])
Predicted class probabilities:
[0.5147107 0.4775849 0.5033135 0.53658676 0.54825157 0.46613103
0.48881835 0.4971995 0.49795172 0.4735152 0.48773417 0.4733628
0.4925578 0.5024685 0.4727704 0.532721 0.53381765 0.53248763
0.51189476 0.4805917 0.47574198 0.49221498 0.50228703 0.47080755
0.493305 0.46142888 0.48300898 0.5016189 0.46455783 0.4659229
0.46562037 0.45953795]
# Put it all together: data module, model, loggers, and trainer
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
model = CNN(lr=0.0001)
tensorboard_logger = TensorBoardLogger('', name='lightning_logs', version=0)
csv_logger = CSVLogger('', name='logs', version=0)
trainer = L.Trainer(max_epochs=10, logger=[tensorboard_logger, csv_logger], devices=1)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint, which comes with seamless uploading to the Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
trainer.fit(model, datamodule=crowns_datamodule)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
  | Name      | Type       | Params | Mode
-------------------------------------------------
0 | model     | Sequential | 23.5 M | train
1 | criterion | BCELoss    | 0      | train
-------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)
154       Modules in train mode
0         Modules in eval mode
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (14) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
# Read the logs CSV file after training
logs_df = pd.read_csv(csv_logger.log_dir + '/metrics.csv')
logs_df = logs_df.groupby('epoch').mean()  # merge the train and valid rows of each epoch
logs_df['epoch'] = logs_df.index  # because 'epoch' gets turned into the index
logs_df.index.name = ''  # remove the name 'epoch' from the index
# Display the logs
print(logs_df)
step train_loss val_loss epoch
0 13.0 0.607735 0.556226 0
1 27.0 0.392539 0.414568 1
2 41.0 0.223743 0.305848 2
3 55.0 0.120550 0.276962 3
4 69.0 0.061415 0.229616 4
5 83.0 0.024479 0.200803 5
6 97.0 0.033815 0.213471 6
7 111.0 0.025451 0.238192 7
8 125.0 0.021924 0.246687 8
9 139.0 0.021966 0.247015 9
#Plot learning curve
plt.figure(figsize=(10, 6))
plt.plot(logs_df['train_loss'], label='Train Loss')
plt.plot(logs_df['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
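The loss table above hints that the validation loss bottoms out around epoch 5 while the training loss keeps shrinking, which usually means the model is starting to overfit. One option is Lightning's built-in EarlyStopping callback; the sketch below is not part of the original workflow and assumes the module logs a metric named 'val_loss', as the CSV above shows:
from lightning.pytorch.callbacks import EarlyStopping
# Stop training once 'val_loss' has not improved for 3 consecutive epochs
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=3)
trainer = L.Trainer(max_epochs=50, callbacks=[early_stopping],
                    logger=[tensorboard_logger, csv_logger], devices=1)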
def calc_test_oa():
    # Test the model on the test set
    out = trainer.predict(model, datamodule=crowns_datamodule, return_predictions=True)
    # Separate predictions, targets, and crown IDs from the output batches
    pred_class_probs = np.concatenate([batch[0] for batch in out])
    obs = np.concatenate([batch[1] for batch in out])
    ids = np.concatenate([batch[2] for batch in out])
    # Build an obs-pred dataframe
    test_df = pd.DataFrame({'obs': obs, 'pred_class_probs': pred_class_probs, 'crown_id': ids})
    # Convert class probabilities to binary predictions
    test_df['pred_boolean_class'] = (test_df['pred_class_probs'] > 0.5)
    # Convert binary predictions to integers
    test_df['pred'] = test_df['pred_boolean_class'].astype(int)
    # Flag correct/incorrect predictions
    test_df['correct'] = test_df['obs'] == test_df['pred']
    # Join with crowns_df to recover metadata (file paths, species type)
    test_df = test_df.merge(crowns_df, on='crown_id', how='left')
    # Calculate overall accuracy using sklearn
    overall_acc = sklearn.metrics.accuracy_score(y_true=test_df['obs'], y_pred=test_df['pred'])
    # Check how many crowns were classified correctly
    n_correct = len(test_df[test_df['correct'] == True])
    print(f"Summary: {n_correct} / {len(test_df)} crowns were classified correctly.")
    return overall_acc, test_df
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 77 / 90 crowns were classified correctly. Overall accuracy: 0.86
print(label_mapping)
# Generate a confusion matrix using sklearn
cm = confusion_matrix(y_true=test_df['obs'],
                      y_pred=test_df['pred'])
# Plot the confusion matrix with seaborn
classes = ['Coniferous', 'Deciduous']
sns.heatmap(cm, annot=True,
            cmap='YlGn',
            xticklabels=classes,
            yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('Observed')
plt.title('Confusion Matrix')
plt.show()
{'coniferous': 0, 'deciduous': 1}
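Overall accuracy can hide class-specific behaviour. Here is a minimal sketch (not part of the original workflow) that reports per-class precision, recall, and F1 from the same obs/pred columns:
from sklearn.metrics import classification_report
# Class order follows label_mapping: 0 = coniferous, 1 = deciduous
print(classification_report(test_df['obs'], test_df['pred'],
                            target_names=['coniferous', 'deciduous']))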
# Let's view the incorrectly classified crowns
incorrect_df = test_df[test_df['correct'] == False]
# Plot incorrectly classified coniferous/deciduous crowns
for c_type in test_df['species_type'].unique():
    print(f"\nIncorrectly classified {c_type} crowns.\n")
    # Filter the incorrect crowns by species type
    incorrect_type_df = test_df[(test_df['correct'] == False) & (test_df['species_type'] == c_type)]
    # Number of images
    num_images = len(incorrect_type_df)
    # Determine the grid size
    grid_size = int(num_images**0.5) + 1
    # Create a figure and axes
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(15, 15))
    # Flatten the axes array for easy iteration
    axes = axes.flatten()
    # Read the incorrect crown files and plot them
    for ax, fpath in zip(axes, incorrect_type_df['fpath']):
        img = Image.open(fpath)
        ax.imshow(img)
        ax.axis('off')
    # Hide any remaining empty subplots
    for ax in axes[num_images:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
Incorrectly classified coniferous crowns.
Incorrectly classified deciduous crowns.
Forget about ML for a second. Imagine you are baking a cookie. There are three things you can change about the recipe: the type of sugar, the baking time, and the oven temperature.
Combining the options for each gives 12 possible variations of the cookie. One of them will be the most delicious.
To find out which cookie tastes best, you need to bake every variation and assign each one a score.
This is called a hyperparameter tune (a hyperparameter sweep). Your three hyperparameters are the sugar type, the baking time, and the baking temperature.
python make_cookie.py --sugar 'white' --baking_time 15 --temperature 400
python make_cookie.py --sugar 'brown' --baking_time 15 --temperature 400
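To make the sweep concrete, here is a minimal sketch that enumerates every combination and launches make_cookie.py once per combination. The option lists below are hypothetical (the text only says there are 12 variations in total, not how they break down):
import itertools
import subprocess

# Hypothetical option lists: 2 sugars x 3 baking times x 2 temperatures = 12 combinations
sugars = ['white', 'brown']
baking_times = [10, 15, 20]
temperatures = [350, 400]

for sugar, time, temp in itertools.product(sugars, baking_times, temperatures):
    # Each run is one "cookie": one hyperparameter combination
    subprocess.run(['python', 'make_cookie.py',
                    '--sugar', sugar,
                    '--baking_time', str(time),
                    '--temperature', str(temp)])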
# TASK 1: Initialize Data Module
# Logging
csv_logger = CSVLogger("logs", name="csv_logs")
tensorboard_logger = TensorBoardLogger("lightning_logs", name="tb_logs")
# TASK 2: Define Model with Tunable Parameters
# Trainer Configuration
trainer = L.Trainer(
max_epochs=10, # Modify number of epochs
logger=[csv_logger, tensorboard_logger],
devices=1 # Adjust based on available hardware
)
# TASK 3: Training
# print result
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint, which comes with seamless uploading to the Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 80 / 90 crowns were classified correctly. Overall accuracy: 0.89
# Reference solution
# TASK 1: Initialize Data Module
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[train_augmentation])
crowns_datamodule.setup()
# Logging
csv_logger = CSVLogger('', name='logs', version=1)
tensorboard_logger = TensorBoardLogger('', name='lightning_logs', version=1)
# TASK 2: Define Model with Tunable Parameters
model = CNN(lr=0.01, pretrained_weights=False)
# Trainer Configuration
trainer = L.Trainer(
max_epochs=10, # Modify number of epochs
logger=[csv_logger, tensorboard_logger],
devices=1 # Adjust based on available hardware
)
# TASK 3: Training
trainer.fit(model, datamodule=crowns_datamodule)
# print result
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint, which comes with seamless uploading to the Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory logs/version_1 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory logs/version_1/checkpoints exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
  | Name      | Type       | Params | Mode
-------------------------------------------------
0 | model     | Sequential | 23.5 M | train
1 | criterion | BCELoss    | 0      | train
-------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)
154       Modules in train mode
0         Modules in eval mode
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (14) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 78 / 90 crowns were classified correctly. Overall accuracy: 0.87
🙋 Question: what combination of parameters produces the best performing model?
The definition of "best" depends on the task at hand. In general, "best" means the lowest validation loss.
If we run this training script with different hyperparameter combinations, each run produces a different loss curve.
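One way to compare runs is to read each run's metrics.csv and overlay the validation-loss curves. A minimal sketch, assuming each run logged to logs/version_*/metrics.csv as the CSVLogger calls above do:
import glob
import pandas as pd
import matplotlib.pyplot as plt

# Overlay the validation-loss curve of every CSV-logged run
for metrics_path in sorted(glob.glob('logs/version_*/metrics.csv')):
    run_df = pd.read_csv(metrics_path).groupby('epoch').mean(numeric_only=True)
    plt.plot(run_df['val_loss'], label=metrics_path.split('/')[1])
plt.xlabel('Epoch')
plt.ylabel('Validation loss')
plt.legend()
plt.show()
TensorBoard offers the same comparison interactively: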
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/
# Save the trained weights to a checkpoint file
trainer.save_checkpoint(filepath="ckpt/model.ckpt")
# Restore the model from the checkpoint (constructor arguments such as lr are re-supplied)
model = CNN.load_from_checkpoint("ckpt/model.ckpt", lr=0.01)
# Freeze all parameters and switch to eval mode for inference
model.freeze()
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
test_predictions = trainer.predict(model, datamodule=crowns_datamodule)
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
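With the frozen model, you can also score a single crown image outside the DataModule. This is a minimal sketch: the file path is hypothetical, the ImageNet normalization constants are an assumption, and it presumes the module's forward returns the sigmoid probability, as the architecture printout above suggests:
import torch
from PIL import Image
from torchvision import transforms

# Match the 256x256 input size used by the batches above
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # assumed ImageNet stats
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open('path/to/crown.png').convert('RGB')  # hypothetical path
x = preprocess(img).unsqueeze(0)  # add batch dimension -> [1, 3, 256, 256]
with torch.no_grad():
    prob = model(x.to(model.device)).item()  # sigmoid output in [0, 1]
print('deciduous' if prob > 0.5 else 'coniferous', prob)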
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 78 / 90 crowns were classified correctly. Overall accuracy: 0.87
You can either bake all the cookies sequentially (at roughly 30 minutes per batch, the 12 variations take about 6 hours), or you can get 12 kitchens, bake them all in parallel, and know the winner in 30 minutes.
If a kitchen is a GPU, then you need 12 GPUs to run every experiment at once and see which cookie is the best. The power of Lightning is the ability to run sweeps like this on 12 different GPUs (or 1,000 GPUs if you'd like) to get you the best version of a model fast.
Train on GPUs
The Trainer will run on all available GPUs by default. Make sure you're running on a machine with at least one GPU. There's no need to specify any NVIDIA flags, as Lightning will do it for you.
from lightning import Trainer
# run on one GPU
trainer = Trainer(accelerator="gpu", devices=1)
# run on multiple GPUs
trainer = Trainer(accelerator="gpu", devices=8)
# run on as many GPUs as available by default
trainer = Trainer(accelerator="auto", devices="auto", strategy="auto")
Train on a Slurm cluster
# train.py
from lightning import Trainer

def main(args):
    model = CNN(args)
    trainer = Trainer(accelerator="gpu", devices=8, num_nodes=4, strategy="ddp")
    trainer.fit(model)

if __name__ == "__main__":
    args = ...  # use your CLI parser of choice, the `LightningCLI`, or a config.yaml
    # TRAIN
    main(args)
%%writefile submit.sh
#!/bin/bash -l
# SLURM SUBMIT SCRIPT (submit.sh)
#SBATCH --nodes=4                # This needs to match Trainer(num_nodes=...)
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8      # This needs to match Trainer(devices=...)
#SBATCH --mem=0
#SBATCH --time=0-02:00:00

# might need the latest CUDA
module load python/3.11 NCCL/2.4.7-1-cuda.10.0

# activate conda env
source activate $1

# run the training script from above
srun python3 train.py
%%!
sbatch submit.sh
Or you can even parallelize the baking procedure...

import wandb
wandb.login()
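Lightning can log straight to Weights & Biases through its WandbLogger; a minimal sketch (the project name is hypothetical):
from lightning.pytorch.loggers import WandbLogger

# Hypothetical project name; replace with your own W&B project
wandb_logger = WandbLogger(project="tree-crown-classification")
trainer = L.Trainer(max_epochs=10, logger=wandb_logger, devices=1)
The cell below embeds a W&B report.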
%%html
<iframe src="https://api.wandb.ai/links/ubc-yuwei-cao/ebnspmv1" style="border:none;height:1024px;width:100%"></iframe>